# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test FusedScaleMaskSoftmax with various configurations"""
import numpy as np

def get_init_params(batch_size: int, num_heads: int, seq_length: int,
                    *, seed: int = 42) -> dict[str, np.ndarray]:
    """
    Generate the input tensors for the FusedScaleMaskSoftmax tests.

    Args:
        batch_size: Size of the first dimension of ``inputs``.
        num_heads: Size of the second dimension of ``inputs``.
        seq_length: Size of each of the last two dimensions of ``inputs`` and of
            both dimensions of ``external_mask`` (square attention-score layout).
        seed: Seed handed to ``np.random.seed`` before drawing. Defaults to 42;
            the hard-coded reference outputs in this module are presumably tied
            to that default, so change it only for new data.

    Returns:
        A dict with:
        ``"inputs"``: float32 array of shape
            (batch_size, num_heads, seq_length, seq_length), drawn uniformly
            from [0, 1) by ``np.random.rand``.
        ``"external_mask"``: int32 array of shape (seq_length, seq_length)
            whose entries are 1 (from True) or 0 (from False). Whether 1 means
            "masked out" is decided by the consumer -- TODO confirm.

    Note:
        This reseeds NumPy's *global* legacy RNG, so any later ``np.random.*``
        calls in the caller become deterministic as a side effect. That
        behaviour is kept deliberately for backward compatibility.
    """
    np.random.seed(seed)
    score_shape = (batch_size, num_heads, seq_length, seq_length)
    # Draw order matters: inputs first, then the mask, to reproduce the
    # exact sequence the reference data was generated with.
    inputs_np = np.random.rand(*score_shape).astype(np.float32)
    mask_np = np.random.choice([True, False], size=(seq_length, seq_length)).astype(np.int32)
    return {
        "inputs": inputs_np,
        "external_mask": mask_np,
    }

def get_golden() -> dict[str, np.ndarray]:
    """Generate golden data for test.

    Returns a dict of hard-coded reference outputs, one per test configuration
    (keys: "output_base", "output_fp16", "output_padding",
    "output_use_construct_mask", "output_scale_0_9"). Each array is written as a
    4 x 2 x 4 x 4 nested literal. This matches the (batch_size, num_heads,
    seq_length, seq_length) layout that get_init_params(4, 2, 4) would produce
    for "inputs". That these are the intended call arguments is an assumption;
    confirm against the test module.

    "output_fp16" is built with dtype=np.float16. Every other array uses
    NumPy's default float dtype (float64), but its values appear rounded to
    bfloat16 precision (e.g. 0.41601562 == 213/512). They were presumably
    captured from a bf16 run; confirm.
    """
    # Base configuration (inferred from the key name). Every entry above the
    # diagonal of each 4x4 slice is 0 (lower-triangular), so presumably a
    # causal mask is applied -- confirm against the test setup.
    output_base = np.array(
        [[[[1.00000000, 0.00000000, 0.00000000, 0.00000000],
           [0.50000000, 0.50000000, 0.00000000, 0.00000000],
           [0.37500000, 0.41601562, 0.20898438, 0.00000000],
           [0.38671875, 0.20800781, 0.20214844, 0.20214844]],
          [[1.00000000, 0.00000000, 0.00000000, 0.00000000],
           [0.61718750, 0.38281250, 0.00000000, 0.00000000],
           [0.31640625, 0.43945312, 0.24414062, 0.00000000],
           [0.30859375, 0.17773438, 0.31250000, 0.20117188]]],
         [[[1.00000000, 0.00000000, 0.00000000, 0.00000000],
           [0.55078125, 0.44921875, 0.00000000, 0.00000000],
           [0.29687500, 0.43164062, 0.27148438, 0.00000000],
           [0.20605469, 0.30859375, 0.21777344, 0.26757812]],
          [[1.00000000, 0.00000000, 0.00000000, 0.00000000],
           [0.51171875, 0.48828125, 0.00000000, 0.00000000],
           [0.32617188, 0.36328125, 0.31250000, 0.00000000],
           [0.22656250, 0.20117188, 0.35156250, 0.21972656]]],
         [[[1.00000000, 0.00000000, 0.00000000, 0.00000000],
           [0.28710938, 0.71484375, 0.00000000, 0.00000000],
           [0.18945312, 0.42773438, 0.38281250, 0.00000000],
           [0.37304688, 0.18554688, 0.24707031, 0.19433594]],
          [[1.00000000, 0.00000000, 0.00000000, 0.00000000],
           [0.49609375, 0.50390625, 0.00000000, 0.00000000],
           [0.47070312, 0.31054688, 0.21875000, 0.00000000],
           [0.27929688, 0.22851562, 0.28125000, 0.21289062]]],
         [[[1.00000000, 0.00000000, 0.00000000, 0.00000000],
           [0.35351562, 0.64843750, 0.00000000, 0.00000000],
           [0.47070312, 0.24316406, 0.28710938, 0.00000000],
           [0.25976562, 0.22265625, 0.27539062, 0.24218750]],
          [[1.00000000, 0.00000000, 0.00000000, 0.00000000],
           [0.64843750, 0.34960938, 0.00000000, 0.00000000],
           [0.36914062, 0.40234375, 0.22656250, 0.00000000],
           [0.16894531, 0.20703125, 0.30468750, 0.31835938]]]])
    # Presumably the module constructs its own mask instead of (or in addition
    # to) the externally supplied one -- inferred from the key name; confirm.
    # Column 0 is 0 in every row of every slice.
    output_use_construct_mask = np.array(
        [[[[0.00000000, 0.55468750, 0.44531250, 0.00000000],
           [0.00000000, 1.00000000, 0.00000000, 0.00000000],
           [0.00000000, 0.35742188, 0.17968750, 0.46289062],
           [0.00000000, 0.00000000, 0.50000000, 0.50000000]],
          [[0.00000000, 0.52343750, 0.47656250, 0.00000000],
           [0.00000000, 1.00000000, 0.00000000, 0.00000000],
           [0.00000000, 0.43164062, 0.24023438, 0.33007812],
           [0.00000000, 0.00000000, 0.60937500, 0.39257812]]],
         [[[0.00000000, 0.49609375, 0.50390625, 0.00000000],
           [0.00000000, 1.00000000, 0.00000000, 0.00000000],
           [0.00000000, 0.31835938, 0.20019531, 0.48046875],
           [0.00000000, 0.00000000, 0.44921875, 0.55078125]],
          [[0.00000000, 0.31250000, 0.68750000, 0.00000000],
           [0.00000000, 1.00000000, 0.00000000, 0.00000000],
           [0.00000000, 0.33398438, 0.28710938, 0.37890625],
           [0.00000000, 0.00000000, 0.61718750, 0.38476562]]],
         [[[0.00000000, 0.59765625, 0.40039062, 0.00000000],
           [0.00000000, 1.00000000, 0.00000000, 0.00000000],
           [0.00000000, 0.35546875, 0.31835938, 0.32617188],
           [0.00000000, 0.00000000, 0.56250000, 0.43945312]],
          [[0.00000000, 0.57421875, 0.42773438, 0.00000000],
           [0.00000000, 1.00000000, 0.00000000, 0.00000000],
           [0.00000000, 0.33593750, 0.23632812, 0.42773438],
           [0.00000000, 0.00000000, 0.57031250, 0.43164062]]],
         [[[0.00000000, 0.59765625, 0.40039062, 0.00000000],
           [0.00000000, 1.00000000, 0.00000000, 0.00000000],
           [0.00000000, 0.26171875, 0.30664062, 0.43164062],
           [0.00000000, 0.00000000, 0.53125000, 0.46875000]],
          [[0.00000000, 0.54296875, 0.45703125, 0.00000000],
           [0.00000000, 1.00000000, 0.00000000, 0.00000000],
           [0.00000000, 0.49609375, 0.27929688, 0.22558594],
           [0.00000000, 0.00000000, 0.49023438, 0.51171875]]]])
    # Presumably the base configuration run in float16 (from the key name).
    # This is the only array built with an explicit dtype (np.float16). It has
    # the same lower-triangular zero pattern as output_base.
    output_fp16 = np.array(
        [[[[1.00000000, 0.00000000, 0.00000000, 0.00000000],
           [0.50000000, 0.50000000, 0.00000000, 0.00000000],
           [0.37426758, 0.41650390, 0.20935059, 0.00000000],
           [0.38720703, 0.20825195, 0.20202637, 0.20239258]],
          [[1.00000000, 0.00000000, 0.00000000, 0.00000000],
           [0.61572266, 0.38403320, 0.00000000, 0.00000000],
           [0.31616210, 0.43920898, 0.24462890, 0.00000000],
           [0.30761720, 0.17822266, 0.31225586, 0.20178223]]],
         [[[1.00000000, 0.00000000, 0.00000000, 0.00000000],
           [0.55175780, 0.44848633, 0.00000000, 0.00000000],
           [0.29687500, 0.43115234, 0.27197266, 0.00000000],
           [0.20617676, 0.30883790, 0.21740723, 0.26782227]],
          [[1.00000000, 0.00000000, 0.00000000, 0.00000000],
           [0.51123047, 0.48901367, 0.00000000, 0.00000000],
           [0.32568360, 0.36254883, 0.31176758, 0.00000000],
           [0.22668457, 0.20166016, 0.35205078, 0.21960449]]],
         [[[1.00000000, 0.00000000, 0.00000000, 0.00000000],
           [0.28662110, 0.71337890, 0.00000000, 0.00000000],
           [0.18994140, 0.42700195, 0.38305664, 0.00000000],
           [0.37329102, 0.18591309, 0.24694824, 0.19384766]],
          [[1.00000000, 0.00000000, 0.00000000, 0.00000000],
           [0.49658203, 0.50341797, 0.00000000, 0.00000000],
           [0.47070312, 0.31079102, 0.21850586, 0.00000000],
           [0.27807617, 0.22778320, 0.28100586, 0.21301270]]],
         [[[1.00000000, 0.00000000, 0.00000000, 0.00000000],
           [0.35327148, 0.64697266, 0.00000000, 0.00000000],
           [0.47045898, 0.24353027, 0.28613280, 0.00000000],
           [0.25927734, 0.22277832, 0.27563477, 0.24230957]],
          [[1.00000000, 0.00000000, 0.00000000, 0.00000000],
           [0.64941406, 0.35034180, 0.00000000, 0.00000000],
           [0.36962890, 0.40380860, 0.22656250, 0.00000000],
           [0.16931152, 0.20654297, 0.30541992, 0.31884766]]]], dtype=np.float16)
    # Padding configuration (inferred from the key name). No entry is 0. The
    # last row of every 4x4 slice equals the last row of the matching slice in
    # output_base; the meaning of that is not visible here -- confirm.
    output_padding = np.array(
        [[[[0.18359375, 0.32617188, 0.26171875, 0.22949219],
           [0.20214844, 0.20214844, 0.18359375, 0.41210938],
           [0.24316406, 0.26953125, 0.13574219, 0.35156250],
           [0.38671875, 0.20800781, 0.20214844, 0.20214844]],
          [[0.22949219, 0.28515625, 0.25976562, 0.22558594],
           [0.32031250, 0.19921875, 0.23144531, 0.25000000],
           [0.23730469, 0.32812500, 0.18261719, 0.25195312],
           [0.30859375, 0.17773438, 0.31250000, 0.20117188]]],
         [[[0.12500000, 0.30273438, 0.30859375, 0.26367188],
           [0.22656250, 0.18359375, 0.33007812, 0.25976562],
           [0.17968750, 0.26171875, 0.16406250, 0.39453125],
           [0.20605469, 0.30859375, 0.21777344, 0.26757812]],
          [[0.22363281, 0.15527344, 0.33984375, 0.28125000],
           [0.27539062, 0.26171875, 0.19433594, 0.26953125],
           [0.23046875, 0.25585938, 0.22070312, 0.29296875],
           [0.22656250, 0.20117188, 0.35156250, 0.21972656]]],
         [[[0.20605469, 0.26757812, 0.17871094, 0.34765625],
           [0.15039062, 0.37500000, 0.30273438, 0.17089844],
           [0.13671875, 0.30664062, 0.27539062, 0.28125000],
           [0.37304688, 0.18554688, 0.24707031, 0.19433594]],
          [[0.35351562, 0.27929688, 0.20800781, 0.15917969],
           [0.20312500, 0.20605469, 0.30859375, 0.28125000],
           [0.33789062, 0.22265625, 0.15625000, 0.28320312],
           [0.27929688, 0.22851562, 0.28125000, 0.21289062]]],
         [[[0.31445312, 0.28515625, 0.19140625, 0.20800781],
           [0.17382812, 0.31835938, 0.23046875, 0.27929688],
           [0.33398438, 0.17382812, 0.20410156, 0.28710938],
           [0.25976562, 0.22265625, 0.27539062, 0.24218750]],
          [[0.27929688, 0.24804688, 0.20800781, 0.26367188],
           [0.29492188, 0.15917969, 0.32031250, 0.22558594],
           [0.31250000, 0.33984375, 0.19140625, 0.15527344],
           [0.16894531, 0.20703125, 0.30468750, 0.31835938]]]])

    # Presumably the base configuration with a softmax scale of 0.9 (from the
    # key name; confirm). Same lower-triangular zero pattern as output_base.
    output_scale_0_9 = np.array(
        [[[[1.00000000, 0.00000000, 0.00000000, 0.00000000],
           [0.50000000, 0.50000000, 0.00000000, 0.00000000],
           [0.37109375, 0.40820312, 0.21972656, 0.00000000],
           [0.37304688, 0.21289062, 0.20703125, 0.20800781]],
          [[1.00000000, 0.00000000, 0.00000000, 0.00000000],
           [0.60546875, 0.39453125, 0.00000000, 0.00000000],
           [0.31835938, 0.42773438, 0.25195312, 0.00000000],
           [0.30273438, 0.18457031, 0.30664062, 0.20703125]]],
         [[[1.00000000, 0.00000000, 0.00000000, 0.00000000],
           [0.54687500, 0.45312500, 0.00000000, 0.00000000],
           [0.30078125, 0.42187500, 0.27734375, 0.00000000],
           [0.20996094, 0.30273438, 0.22070312, 0.26562500]],
          [[1.00000000, 0.00000000, 0.00000000, 0.00000000],
           [0.51171875, 0.49023438, 0.00000000, 0.00000000],
           [0.32617188, 0.35937500, 0.31445312, 0.00000000],
           [0.22949219, 0.20605469, 0.34179688, 0.22265625]]],
         [[[1.00000000, 0.00000000, 0.00000000, 0.00000000],
           [0.30468750, 0.69531250, 0.00000000, 0.00000000],
           [0.20214844, 0.41796875, 0.37890625, 0.00000000],
           [0.35937500, 0.19238281, 0.24902344, 0.19921875]],
          [[1.00000000, 0.00000000, 0.00000000, 0.00000000],
           [0.49609375, 0.50390625, 0.00000000, 0.00000000],
           [0.45703125, 0.31445312, 0.22851562, 0.00000000],
           [0.27539062, 0.23046875, 0.27734375, 0.21679688]]],
         [[[1.00000000, 0.00000000, 0.00000000, 0.00000000],
           [0.36718750, 0.63281250, 0.00000000, 0.00000000],
           [0.45507812, 0.25195312, 0.29101562, 0.00000000],
           [0.25781250, 0.22558594, 0.27343750, 0.24316406]],
          [[1.00000000, 0.00000000, 0.00000000, 0.00000000],
           [0.63671875, 0.36523438, 0.00000000, 0.00000000],
           [0.36718750, 0.39648438, 0.23632812, 0.00000000],
           [0.17675781, 0.21191406, 0.30078125, 0.31250000]]]])
    return {
        "output_base": output_base,
        "output_fp16": output_fp16,
        "output_padding": output_padding,
        "output_use_construct_mask": output_use_construct_mask,
        "output_scale_0_9": output_scale_0_9,
    }

def get_gpu_datas() -> dict[str, np.ndarray]:
    """Generate gpu data for test.

    Returns hard-coded reference outputs under the same five keys as
    get_golden(): "output_base", "output_padding", "output_fp16",
    "output_use_construct_mask", "output_scale_0_9". Each array is a
    4 x 2 x 4 x 4 nested literal. "output_fp16" is built with
    dtype=np.float16; the rest use NumPy's default float dtype (float64).

    NOTE(review): at present every array here is value-identical to the
    matching entry returned by get_golden(). The dict key order also differs
    (output_padding precedes output_fp16), but that does not affect key
    lookups. If the two references are meant to stay independent captures,
    keep them separate. Otherwise consider deriving one from the other to
    avoid duplicated data.
    """
    # Base configuration; lower-triangular zero pattern (see get_golden()).
    output_base = np.array(
        [[[[1.00000000, 0.00000000, 0.00000000, 0.00000000],
           [0.50000000, 0.50000000, 0.00000000, 0.00000000],
           [0.37500000, 0.41601562, 0.20898438, 0.00000000],
           [0.38671875, 0.20800781, 0.20214844, 0.20214844]],

          [[1.00000000, 0.00000000, 0.00000000, 0.00000000],
           [0.61718750, 0.38281250, 0.00000000, 0.00000000],
           [0.31640625, 0.43945312, 0.24414062, 0.00000000],
           [0.30859375, 0.17773438, 0.31250000, 0.20117188]]],


         [[[1.00000000, 0.00000000, 0.00000000, 0.00000000],
           [0.55078125, 0.44921875, 0.00000000, 0.00000000],
           [0.29687500, 0.43164062, 0.27148438, 0.00000000],
           [0.20605469, 0.30859375, 0.21777344, 0.26757812]],

          [[1.00000000, 0.00000000, 0.00000000, 0.00000000],
           [0.51171875, 0.48828125, 0.00000000, 0.00000000],
           [0.32617188, 0.36328125, 0.31250000, 0.00000000],
           [0.22656250, 0.20117188, 0.35156250, 0.21972656]]],


         [[[1.00000000, 0.00000000, 0.00000000, 0.00000000],
           [0.28710938, 0.71484375, 0.00000000, 0.00000000],
           [0.18945312, 0.42773438, 0.38281250, 0.00000000],
           [0.37304688, 0.18554688, 0.24707031, 0.19433594]],

          [[1.00000000, 0.00000000, 0.00000000, 0.00000000],
           [0.49609375, 0.50390625, 0.00000000, 0.00000000],
           [0.47070312, 0.31054688, 0.21875000, 0.00000000],
           [0.27929688, 0.22851562, 0.28125000, 0.21289062]]],


         [[[1.00000000, 0.00000000, 0.00000000, 0.00000000],
           [0.35351562, 0.64843750, 0.00000000, 0.00000000],
           [0.47070312, 0.24316406, 0.28710938, 0.00000000],
           [0.25976562, 0.22265625, 0.27539062, 0.24218750]],

          [[1.00000000, 0.00000000, 0.00000000, 0.00000000],
           [0.64843750, 0.34960938, 0.00000000, 0.00000000],
           [0.36914062, 0.40234375, 0.22656250, 0.00000000],
           [0.16894531, 0.20703125, 0.30468750, 0.31835938]]]])
    # Presumably the module-constructed-mask configuration (from the key name);
    # column 0 is 0 in every row.
    output_use_construct_mask = np.array(
        [[[[0.00000000, 0.55468750, 0.44531250, 0.00000000],
           [0.00000000, 1.00000000, 0.00000000, 0.00000000],
           [0.00000000, 0.35742188, 0.17968750, 0.46289062],
           [0.00000000, 0.00000000, 0.50000000, 0.50000000]],
          [[0.00000000, 0.52343750, 0.47656250, 0.00000000],
           [0.00000000, 1.00000000, 0.00000000, 0.00000000],
           [0.00000000, 0.43164062, 0.24023438, 0.33007812],
           [0.00000000, 0.00000000, 0.60937500, 0.39257812]]],
         [[[0.00000000, 0.49609375, 0.50390625, 0.00000000],
           [0.00000000, 1.00000000, 0.00000000, 0.00000000],
           [0.00000000, 0.31835938, 0.20019531, 0.48046875],
           [0.00000000, 0.00000000, 0.44921875, 0.55078125]],
          [[0.00000000, 0.31250000, 0.68750000, 0.00000000],
           [0.00000000, 1.00000000, 0.00000000, 0.00000000],
           [0.00000000, 0.33398438, 0.28710938, 0.37890625],
           [0.00000000, 0.00000000, 0.61718750, 0.38476562]]],
         [[[0.00000000, 0.59765625, 0.40039062, 0.00000000],
           [0.00000000, 1.00000000, 0.00000000, 0.00000000],
           [0.00000000, 0.35546875, 0.31835938, 0.32617188],
           [0.00000000, 0.00000000, 0.56250000, 0.43945312]],
          [[0.00000000, 0.57421875, 0.42773438, 0.00000000],
           [0.00000000, 1.00000000, 0.00000000, 0.00000000],
           [0.00000000, 0.33593750, 0.23632812, 0.42773438],
           [0.00000000, 0.00000000, 0.57031250, 0.43164062]]],
         [[[0.00000000, 0.59765625, 0.40039062, 0.00000000],
           [0.00000000, 1.00000000, 0.00000000, 0.00000000],
           [0.00000000, 0.26171875, 0.30664062, 0.43164062],
           [0.00000000, 0.00000000, 0.53125000, 0.46875000]],
          [[0.00000000, 0.54296875, 0.45703125, 0.00000000],
           [0.00000000, 1.00000000, 0.00000000, 0.00000000],
           [0.00000000, 0.49609375, 0.27929688, 0.22558594],
           [0.00000000, 0.00000000, 0.49023438, 0.51171875]]]]
    )
    # Presumably the float16 run of the base configuration; the only array
    # built with an explicit dtype (np.float16).
    output_fp16 = np.array(
        [[[[1.00000000, 0.00000000, 0.00000000, 0.00000000],
           [0.50000000, 0.50000000, 0.00000000, 0.00000000],
           [0.37426758, 0.41650390, 0.20935059, 0.00000000],
           [0.38720703, 0.20825195, 0.20202637, 0.20239258]],
          [[1.00000000, 0.00000000, 0.00000000, 0.00000000],
           [0.61572266, 0.38403320, 0.00000000, 0.00000000],
           [0.31616210, 0.43920898, 0.24462890, 0.00000000],
           [0.30761720, 0.17822266, 0.31225586, 0.20178223]]],
         [[[1.00000000, 0.00000000, 0.00000000, 0.00000000],
           [0.55175780, 0.44848633, 0.00000000, 0.00000000],
           [0.29687500, 0.43115234, 0.27197266, 0.00000000],
           [0.20617676, 0.30883790, 0.21740723, 0.26782227]],
          [[1.00000000, 0.00000000, 0.00000000, 0.00000000],
           [0.51123047, 0.48901367, 0.00000000, 0.00000000],
           [0.32568360, 0.36254883, 0.31176758, 0.00000000],
           [0.22668457, 0.20166016, 0.35205078, 0.21960449]]],
         [[[1.00000000, 0.00000000, 0.00000000, 0.00000000],
           [0.28662110, 0.71337890, 0.00000000, 0.00000000],
           [0.18994140, 0.42700195, 0.38305664, 0.00000000],
           [0.37329102, 0.18591309, 0.24694824, 0.19384766]],
          [[1.00000000, 0.00000000, 0.00000000, 0.00000000],
           [0.49658203, 0.50341797, 0.00000000, 0.00000000],
           [0.47070312, 0.31079102, 0.21850586, 0.00000000],
           [0.27807617, 0.22778320, 0.28100586, 0.21301270]]],
         [[[1.00000000, 0.00000000, 0.00000000, 0.00000000],
           [0.35327148, 0.64697266, 0.00000000, 0.00000000],
           [0.47045898, 0.24353027, 0.28613280, 0.00000000],
           [0.25927734, 0.22277832, 0.27563477, 0.24230957]],
          [[1.00000000, 0.00000000, 0.00000000, 0.00000000],
           [0.64941406, 0.35034180, 0.00000000, 0.00000000],
           [0.36962890, 0.40380860, 0.22656250, 0.00000000],
           [0.16931152, 0.20654297, 0.30541992, 0.31884766]]]], dtype=np.float16)
    # Padding configuration (from the key name); no zero entries.
    output_padding = np.array(
        [[[[0.18359375, 0.32617188, 0.26171875, 0.22949219],
           [0.20214844, 0.20214844, 0.18359375, 0.41210938],
           [0.24316406, 0.26953125, 0.13574219, 0.35156250],
           [0.38671875, 0.20800781, 0.20214844, 0.20214844]],
          [[0.22949219, 0.28515625, 0.25976562, 0.22558594],
           [0.32031250, 0.19921875, 0.23144531, 0.25000000],
           [0.23730469, 0.32812500, 0.18261719, 0.25195312],
           [0.30859375, 0.17773438, 0.31250000, 0.20117188]]],
         [[[0.12500000, 0.30273438, 0.30859375, 0.26367188],
           [0.22656250, 0.18359375, 0.33007812, 0.25976562],
           [0.17968750, 0.26171875, 0.16406250, 0.39453125],
           [0.20605469, 0.30859375, 0.21777344, 0.26757812]],
          [[0.22363281, 0.15527344, 0.33984375, 0.28125000],
           [0.27539062, 0.26171875, 0.19433594, 0.26953125],
           [0.23046875, 0.25585938, 0.22070312, 0.29296875],
           [0.22656250, 0.20117188, 0.35156250, 0.21972656]]],
         [[[0.20605469, 0.26757812, 0.17871094, 0.34765625],
           [0.15039062, 0.37500000, 0.30273438, 0.17089844],
           [0.13671875, 0.30664062, 0.27539062, 0.28125000],
           [0.37304688, 0.18554688, 0.24707031, 0.19433594]],
          [[0.35351562, 0.27929688, 0.20800781, 0.15917969],
           [0.20312500, 0.20605469, 0.30859375, 0.28125000],
           [0.33789062, 0.22265625, 0.15625000, 0.28320312],
           [0.27929688, 0.22851562, 0.28125000, 0.21289062]]],
         [[[0.31445312, 0.28515625, 0.19140625, 0.20800781],
           [0.17382812, 0.31835938, 0.23046875, 0.27929688],
           [0.33398438, 0.17382812, 0.20410156, 0.28710938],
           [0.25976562, 0.22265625, 0.27539062, 0.24218750]],
          [[0.27929688, 0.24804688, 0.20800781, 0.26367188],
           [0.29492188, 0.15917969, 0.32031250, 0.22558594],
           [0.31250000, 0.33984375, 0.19140625, 0.15527344],
           [0.16894531, 0.20703125, 0.30468750, 0.31835938]]]])
    # Presumably the base configuration with a softmax scale of 0.9 (from the
    # key name; confirm).
    output_scale_0_9 = np.array(
        [[[[1.00000000, 0.00000000, 0.00000000, 0.00000000],
           [0.50000000, 0.50000000, 0.00000000, 0.00000000],
           [0.37109375, 0.40820312, 0.21972656, 0.00000000],
           [0.37304688, 0.21289062, 0.20703125, 0.20800781]],
          [[1.00000000, 0.00000000, 0.00000000, 0.00000000],
           [0.60546875, 0.39453125, 0.00000000, 0.00000000],
           [0.31835938, 0.42773438, 0.25195312, 0.00000000],
           [0.30273438, 0.18457031, 0.30664062, 0.20703125]]],
         [[[1.00000000, 0.00000000, 0.00000000, 0.00000000],
           [0.54687500, 0.45312500, 0.00000000, 0.00000000],
           [0.30078125, 0.42187500, 0.27734375, 0.00000000],
           [0.20996094, 0.30273438, 0.22070312, 0.26562500]],
          [[1.00000000, 0.00000000, 0.00000000, 0.00000000],
           [0.51171875, 0.49023438, 0.00000000, 0.00000000],
           [0.32617188, 0.35937500, 0.31445312, 0.00000000],
           [0.22949219, 0.20605469, 0.34179688, 0.22265625]]],
         [[[1.00000000, 0.00000000, 0.00000000, 0.00000000],
           [0.30468750, 0.69531250, 0.00000000, 0.00000000],
           [0.20214844, 0.41796875, 0.37890625, 0.00000000],
           [0.35937500, 0.19238281, 0.24902344, 0.19921875]],
          [[1.00000000, 0.00000000, 0.00000000, 0.00000000],
           [0.49609375, 0.50390625, 0.00000000, 0.00000000],
           [0.45703125, 0.31445312, 0.22851562, 0.00000000],
           [0.27539062, 0.23046875, 0.27734375, 0.21679688]]],
         [[[1.00000000, 0.00000000, 0.00000000, 0.00000000],
           [0.36718750, 0.63281250, 0.00000000, 0.00000000],
           [0.45507812, 0.25195312, 0.29101562, 0.00000000],
           [0.25781250, 0.22558594, 0.27343750, 0.24316406]],
          [[1.00000000, 0.00000000, 0.00000000, 0.00000000],
           [0.63671875, 0.36523438, 0.00000000, 0.00000000],
           [0.36718750, 0.39648438, 0.23632812, 0.00000000],
           [0.17675781, 0.21191406, 0.30078125, 0.31250000]]]])
    return {
        "output_base": output_base,
        "output_padding": output_padding,
        "output_fp16": output_fp16,
        "output_use_construct_mask": output_use_construct_mask,
        "output_scale_0_9": output_scale_0_9,
    }

# Build both reference dictionaries once, at import time. The golden set is
# evaluated first, then the GPU set, exactly as two separate assignments would.
# These are presumably consumed by the test modules via import -- confirm.
GOLDEN_DATA, GPU_DATA = get_golden(), get_gpu_datas()
